{
 "cells": [
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "KLikEnWjnm5G",
    "colab_type": "text"
   },
   "source": [
    "# Load & Predict\n",
    "\n",
    "## Download Pretrained Weights"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 1,
   "metadata": {
    "id": "5RSVDVIGngOz",
    "colab_type": "code",
    "colab": {}
   },
   "outputs": [],
   "source": [
    "# %pip (unlike !pip) installs into the environment of the running kernel.\n",
    "%pip install -q keras-bert"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 2,
   "metadata": {
    "id": "xCGEc8M2nvBw",
    "colab_type": "code",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 204.0
    },
    "outputId": "e1b3fd7c-6a37-41a5-f77d-583e65b8a66e"
   },
   "outputs": [],
   "source": [
    "# -nc (no-clobber) skips the download when the archive already exists, so\n",
    "# re-running this cell does not fetch it again and leave a duplicate '.zip.1'.\n",
    "# Delete the archive first if an earlier download was interrupted.\n",
    "!wget -q -nc https://storage.googleapis.com/bert_models/2018_11_03/chinese_L-12_H-768_A-12.zip"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 3,
   "metadata": {
    "id": "ImjXK-o9nwDf",
    "colab_type": "code",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 119.0
    },
    "outputId": "cdc5680e-f4ee-4d78-8c8e-13eb9a702b51"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Archive:  chinese_L-12_H-768_A-12.zip\n",
      "  inflating: chinese_L-12_H-768_A-12/bert_model.ckpt.meta  \n",
      "  inflating: chinese_L-12_H-768_A-12/bert_model.ckpt.data-00000-of-00001  \n",
      "  inflating: chinese_L-12_H-768_A-12/vocab.txt  \n",
      "  inflating: chinese_L-12_H-768_A-12/bert_model.ckpt.index  \n",
      "  inflating: chinese_L-12_H-768_A-12/bert_config.json  \n"
     ]
    }
   ],
   "source": [
    "!unzip -o chinese_L-12_H-768_A-12.zip"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ivWbWgG1n0Y_",
    "colab_type": "text"
   },
   "source": [
    "## Build Model & Dictionary"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "ZNceNEYDn2kN",
    "colab_type": "text"
   },
   "source": [
    "Set paths:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 4,
   "metadata": {
    "id": "VFqFbUxhn1Un",
    "colab_type": "code",
    "colab": {}
   },
   "outputs": [],
   "source": [
    "import os\n",
    "\n",
    "# Directory created by unzipping the pretrained archive.\n",
    "pretrained_path = 'chinese_L-12_H-768_A-12'\n",
    "\n",
    "# 'bert_model.ckpt' is the common prefix of the extracted .index/.meta/.data\n",
    "# files, not a file of its own.\n",
    "config_path, checkpoint_path, vocab_path = (\n",
    "    os.path.join(pretrained_path, filename)\n",
    "    for filename in ('bert_config.json', 'bert_model.ckpt', 'vocab.txt')\n",
    ")"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "L9Vvv4nFn8W_",
    "colab_type": "text"
   },
   "source": [
    "Enable `tf.keras` by setting the `TF_KERAS` environment variable. Run this cell before `keras_bert` is imported, because the backend is chosen at import time:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 5,
   "metadata": {
    "id": "hzQsp1T7n_fa",
    "colab_type": "code",
    "colab": {}
   },
   "outputs": [],
   "source": [
    "os.environ['TF_KERAS'] = '1'"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "XbW8fyUeoBZ6",
    "colab_type": "text"
   },
   "source": [
    "Build the token → index dictionary and its inverse (index → token) from the vocabulary file:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 6,
   "metadata": {
    "id": "wMCnEEYloDiQ",
    "colab_type": "code",
    "colab": {}
   },
   "outputs": [],
   "source": [
    "import codecs\n",
    "\n",
    "# Index each token by its zero-based line number in vocab.txt. Using the line\n",
    "# number (rather than len(token_dict)) keeps later indices from shifting if two\n",
    "# lines happen to strip to the same token.\n",
    "token_dict = {}\n",
    "# Filled in the same pass so that every line number has an entry.\n",
    "token_dict_inv = {}\n",
    "with codecs.open(vocab_path, 'r', 'utf8') as reader:\n",
    "    for index, line in enumerate(reader):\n",
    "        token = line.strip()\n",
    "        token_dict[token] = index\n",
    "        token_dict_inv[index] = token"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "xwrGrzZ9oFzZ",
    "colab_type": "text"
   },
   "source": [
    "Build the model:"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 7,
   "metadata": {
    "id": "G6W-VuSroIKO",
    "colab_type": "code",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 4525.0
    },
    "outputId": "3abc830c-8438-4127-bbaa-83476f27b68a"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/ops/resource_variable_ops.py:435: colocate_with (from tensorflow.python.framework.ops) is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Colocations handled automatically by placer.\n",
      "WARNING:tensorflow:From /usr/local/lib/python3.6/dist-packages/tensorflow/python/keras/layers/core.py:143: calling dropout (from tensorflow.python.ops.nn_ops) with keep_prob is deprecated and will be removed in a future version.\n",
      "Instructions for updating:\n",
      "Please use `rate` instead of `keep_prob`. Rate should be set to `rate = 1 - keep_prob`.\n",
      "________________________________________________________________________________________________________________________\n",
      "Layer (type)                           Output Shape               Param #       Connected to                            \n",
      "========================================================================================================================\n",
      "Input-Token (InputLayer)               (None, 512)                0                                                     \n",
      "________________________________________________________________________________________________________________________\n",
      "Input-Segment (InputLayer)             (None, 512)                0                                                     \n",
      "________________________________________________________________________________________________________________________\n",
      "Embedding-Token (TokenEmbedding)       [(None, 512, 768), (21128, 16226304      Input-Token[0][0]                       \n",
      "________________________________________________________________________________________________________________________\n",
      "Embedding-Segment (Embedding)          (None, 512, 768)           1536          Input-Segment[0][0]                     \n",
      "________________________________________________________________________________________________________________________\n",
      "Embedding-Token-Segment (Add)          (None, 512, 768)           0             Embedding-Token[0][0]                   \n",
      "                                                                                Embedding-Segment[0][0]                 \n",
      "________________________________________________________________________________________________________________________\n",
      "Embedding-Position (PositionEmbedding) (None, 512, 768)           393216        Embedding-Token-Segment[0][0]           \n",
      "________________________________________________________________________________________________________________________\n",
      "Embedding-Dropout (Dropout)            (None, 512, 768)           0             Embedding-Position[0][0]                \n",
      "________________________________________________________________________________________________________________________\n",
      "Embedding-Norm (LayerNormalization)    (None, 512, 768)           1536          Embedding-Dropout[0][0]                 \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-1-MultiHeadSelfAttention (Mult (None, None, 768)          2362368       Embedding-Norm[0][0]                    \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-1-MultiHeadSelfAttention-Dropo (None, None, 768)          0             Encoder-1-MultiHeadSelfAttention[0][0]  \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-1-MultiHeadSelfAttention-Add ( (None, 512, 768)           0             Embedding-Norm[0][0]                    \n",
      "                                                                                Encoder-1-MultiHeadSelfAttention-Dropout\n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-1-MultiHeadSelfAttention-Norm  (None, 512, 768)           1536          Encoder-1-MultiHeadSelfAttention-Add[0][\n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-1-FeedForward (FeedForward)    (None, 512, 768)           4722432       Encoder-1-MultiHeadSelfAttention-Norm[0]\n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-1-FeedForward-Dropout (Dropout (None, 512, 768)           0             Encoder-1-FeedForward[0][0]             \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-1-FeedForward-Add (Add)        (None, 512, 768)           0             Encoder-1-MultiHeadSelfAttention-Norm[0]\n",
      "                                                                                Encoder-1-FeedForward-Dropout[0][0]     \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-1-FeedForward-Norm (LayerNorma (None, 512, 768)           1536          Encoder-1-FeedForward-Add[0][0]         \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-2-MultiHeadSelfAttention (Mult (None, None, 768)          2362368       Encoder-1-FeedForward-Norm[0][0]        \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-2-MultiHeadSelfAttention-Dropo (None, None, 768)          0             Encoder-2-MultiHeadSelfAttention[0][0]  \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-2-MultiHeadSelfAttention-Add ( (None, 512, 768)           0             Encoder-1-FeedForward-Norm[0][0]        \n",
      "                                                                                Encoder-2-MultiHeadSelfAttention-Dropout\n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-2-MultiHeadSelfAttention-Norm  (None, 512, 768)           1536          Encoder-2-MultiHeadSelfAttention-Add[0][\n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-2-FeedForward (FeedForward)    (None, 512, 768)           4722432       Encoder-2-MultiHeadSelfAttention-Norm[0]\n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-2-FeedForward-Dropout (Dropout (None, 512, 768)           0             Encoder-2-FeedForward[0][0]             \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-2-FeedForward-Add (Add)        (None, 512, 768)           0             Encoder-2-MultiHeadSelfAttention-Norm[0]\n",
      "                                                                                Encoder-2-FeedForward-Dropout[0][0]     \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-2-FeedForward-Norm (LayerNorma (None, 512, 768)           1536          Encoder-2-FeedForward-Add[0][0]         \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-3-MultiHeadSelfAttention (Mult (None, None, 768)          2362368       Encoder-2-FeedForward-Norm[0][0]        \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-3-MultiHeadSelfAttention-Dropo (None, None, 768)          0             Encoder-3-MultiHeadSelfAttention[0][0]  \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-3-MultiHeadSelfAttention-Add ( (None, 512, 768)           0             Encoder-2-FeedForward-Norm[0][0]        \n",
      "                                                                                Encoder-3-MultiHeadSelfAttention-Dropout\n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-3-MultiHeadSelfAttention-Norm  (None, 512, 768)           1536          Encoder-3-MultiHeadSelfAttention-Add[0][\n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-3-FeedForward (FeedForward)    (None, 512, 768)           4722432       Encoder-3-MultiHeadSelfAttention-Norm[0]\n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-3-FeedForward-Dropout (Dropout (None, 512, 768)           0             Encoder-3-FeedForward[0][0]             \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-3-FeedForward-Add (Add)        (None, 512, 768)           0             Encoder-3-MultiHeadSelfAttention-Norm[0]\n",
      "                                                                                Encoder-3-FeedForward-Dropout[0][0]     \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-3-FeedForward-Norm (LayerNorma (None, 512, 768)           1536          Encoder-3-FeedForward-Add[0][0]         \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-4-MultiHeadSelfAttention (Mult (None, None, 768)          2362368       Encoder-3-FeedForward-Norm[0][0]        \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-4-MultiHeadSelfAttention-Dropo (None, None, 768)          0             Encoder-4-MultiHeadSelfAttention[0][0]  \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-4-MultiHeadSelfAttention-Add ( (None, 512, 768)           0             Encoder-3-FeedForward-Norm[0][0]        \n",
      "                                                                                Encoder-4-MultiHeadSelfAttention-Dropout\n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-4-MultiHeadSelfAttention-Norm  (None, 512, 768)           1536          Encoder-4-MultiHeadSelfAttention-Add[0][\n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-4-FeedForward (FeedForward)    (None, 512, 768)           4722432       Encoder-4-MultiHeadSelfAttention-Norm[0]\n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-4-FeedForward-Dropout (Dropout (None, 512, 768)           0             Encoder-4-FeedForward[0][0]             \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-4-FeedForward-Add (Add)        (None, 512, 768)           0             Encoder-4-MultiHeadSelfAttention-Norm[0]\n",
      "                                                                                Encoder-4-FeedForward-Dropout[0][0]     \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-4-FeedForward-Norm (LayerNorma (None, 512, 768)           1536          Encoder-4-FeedForward-Add[0][0]         \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-5-MultiHeadSelfAttention (Mult (None, None, 768)          2362368       Encoder-4-FeedForward-Norm[0][0]        \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-5-MultiHeadSelfAttention-Dropo (None, None, 768)          0             Encoder-5-MultiHeadSelfAttention[0][0]  \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-5-MultiHeadSelfAttention-Add ( (None, 512, 768)           0             Encoder-4-FeedForward-Norm[0][0]        \n",
      "                                                                                Encoder-5-MultiHeadSelfAttention-Dropout\n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-5-MultiHeadSelfAttention-Norm  (None, 512, 768)           1536          Encoder-5-MultiHeadSelfAttention-Add[0][\n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-5-FeedForward (FeedForward)    (None, 512, 768)           4722432       Encoder-5-MultiHeadSelfAttention-Norm[0]\n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-5-FeedForward-Dropout (Dropout (None, 512, 768)           0             Encoder-5-FeedForward[0][0]             \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-5-FeedForward-Add (Add)        (None, 512, 768)           0             Encoder-5-MultiHeadSelfAttention-Norm[0]\n",
      "                                                                                Encoder-5-FeedForward-Dropout[0][0]     \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-5-FeedForward-Norm (LayerNorma (None, 512, 768)           1536          Encoder-5-FeedForward-Add[0][0]         \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-6-MultiHeadSelfAttention (Mult (None, None, 768)          2362368       Encoder-5-FeedForward-Norm[0][0]        \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-6-MultiHeadSelfAttention-Dropo (None, None, 768)          0             Encoder-6-MultiHeadSelfAttention[0][0]  \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-6-MultiHeadSelfAttention-Add ( (None, 512, 768)           0             Encoder-5-FeedForward-Norm[0][0]        \n",
      "                                                                                Encoder-6-MultiHeadSelfAttention-Dropout\n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-6-MultiHeadSelfAttention-Norm  (None, 512, 768)           1536          Encoder-6-MultiHeadSelfAttention-Add[0][\n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-6-FeedForward (FeedForward)    (None, 512, 768)           4722432       Encoder-6-MultiHeadSelfAttention-Norm[0]\n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-6-FeedForward-Dropout (Dropout (None, 512, 768)           0             Encoder-6-FeedForward[0][0]             \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-6-FeedForward-Add (Add)        (None, 512, 768)           0             Encoder-6-MultiHeadSelfAttention-Norm[0]\n",
      "                                                                                Encoder-6-FeedForward-Dropout[0][0]     \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-6-FeedForward-Norm (LayerNorma (None, 512, 768)           1536          Encoder-6-FeedForward-Add[0][0]         \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-7-MultiHeadSelfAttention (Mult (None, None, 768)          2362368       Encoder-6-FeedForward-Norm[0][0]        \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-7-MultiHeadSelfAttention-Dropo (None, None, 768)          0             Encoder-7-MultiHeadSelfAttention[0][0]  \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-7-MultiHeadSelfAttention-Add ( (None, 512, 768)           0             Encoder-6-FeedForward-Norm[0][0]        \n",
      "                                                                                Encoder-7-MultiHeadSelfAttention-Dropout\n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-7-MultiHeadSelfAttention-Norm  (None, 512, 768)           1536          Encoder-7-MultiHeadSelfAttention-Add[0][\n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-7-FeedForward (FeedForward)    (None, 512, 768)           4722432       Encoder-7-MultiHeadSelfAttention-Norm[0]\n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-7-FeedForward-Dropout (Dropout (None, 512, 768)           0             Encoder-7-FeedForward[0][0]             \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-7-FeedForward-Add (Add)        (None, 512, 768)           0             Encoder-7-MultiHeadSelfAttention-Norm[0]\n",
      "                                                                                Encoder-7-FeedForward-Dropout[0][0]     \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-7-FeedForward-Norm (LayerNorma (None, 512, 768)           1536          Encoder-7-FeedForward-Add[0][0]         \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-8-MultiHeadSelfAttention (Mult (None, None, 768)          2362368       Encoder-7-FeedForward-Norm[0][0]        \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-8-MultiHeadSelfAttention-Dropo (None, None, 768)          0             Encoder-8-MultiHeadSelfAttention[0][0]  \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-8-MultiHeadSelfAttention-Add ( (None, 512, 768)           0             Encoder-7-FeedForward-Norm[0][0]        \n",
      "                                                                                Encoder-8-MultiHeadSelfAttention-Dropout\n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-8-MultiHeadSelfAttention-Norm  (None, 512, 768)           1536          Encoder-8-MultiHeadSelfAttention-Add[0][\n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-8-FeedForward (FeedForward)    (None, 512, 768)           4722432       Encoder-8-MultiHeadSelfAttention-Norm[0]\n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-8-FeedForward-Dropout (Dropout (None, 512, 768)           0             Encoder-8-FeedForward[0][0]             \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-8-FeedForward-Add (Add)        (None, 512, 768)           0             Encoder-8-MultiHeadSelfAttention-Norm[0]\n",
      "                                                                                Encoder-8-FeedForward-Dropout[0][0]     \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-8-FeedForward-Norm (LayerNorma (None, 512, 768)           1536          Encoder-8-FeedForward-Add[0][0]         \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-9-MultiHeadSelfAttention (Mult (None, None, 768)          2362368       Encoder-8-FeedForward-Norm[0][0]        \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-9-MultiHeadSelfAttention-Dropo (None, None, 768)          0             Encoder-9-MultiHeadSelfAttention[0][0]  \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-9-MultiHeadSelfAttention-Add ( (None, 512, 768)           0             Encoder-8-FeedForward-Norm[0][0]        \n",
      "                                                                                Encoder-9-MultiHeadSelfAttention-Dropout\n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-9-MultiHeadSelfAttention-Norm  (None, 512, 768)           1536          Encoder-9-MultiHeadSelfAttention-Add[0][\n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-9-FeedForward (FeedForward)    (None, 512, 768)           4722432       Encoder-9-MultiHeadSelfAttention-Norm[0]\n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-9-FeedForward-Dropout (Dropout (None, 512, 768)           0             Encoder-9-FeedForward[0][0]             \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-9-FeedForward-Add (Add)        (None, 512, 768)           0             Encoder-9-MultiHeadSelfAttention-Norm[0]\n",
      "                                                                                Encoder-9-FeedForward-Dropout[0][0]     \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-9-FeedForward-Norm (LayerNorma (None, 512, 768)           1536          Encoder-9-FeedForward-Add[0][0]         \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-10-MultiHeadSelfAttention (Mul (None, None, 768)          2362368       Encoder-9-FeedForward-Norm[0][0]        \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-10-MultiHeadSelfAttention-Drop (None, None, 768)          0             Encoder-10-MultiHeadSelfAttention[0][0] \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-10-MultiHeadSelfAttention-Add  (None, 512, 768)           0             Encoder-9-FeedForward-Norm[0][0]        \n",
      "                                                                                Encoder-10-MultiHeadSelfAttention-Dropou\n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-10-MultiHeadSelfAttention-Norm (None, 512, 768)           1536          Encoder-10-MultiHeadSelfAttention-Add[0]\n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-10-FeedForward (FeedForward)   (None, 512, 768)           4722432       Encoder-10-MultiHeadSelfAttention-Norm[0\n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-10-FeedForward-Dropout (Dropou (None, 512, 768)           0             Encoder-10-FeedForward[0][0]            \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-10-FeedForward-Add (Add)       (None, 512, 768)           0             Encoder-10-MultiHeadSelfAttention-Norm[0\n",
      "                                                                                Encoder-10-FeedForward-Dropout[0][0]    \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-10-FeedForward-Norm (LayerNorm (None, 512, 768)           1536          Encoder-10-FeedForward-Add[0][0]        \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-11-MultiHeadSelfAttention (Mul (None, None, 768)          2362368       Encoder-10-FeedForward-Norm[0][0]       \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-11-MultiHeadSelfAttention-Drop (None, None, 768)          0             Encoder-11-MultiHeadSelfAttention[0][0] \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-11-MultiHeadSelfAttention-Add  (None, 512, 768)           0             Encoder-10-FeedForward-Norm[0][0]       \n",
      "                                                                                Encoder-11-MultiHeadSelfAttention-Dropou\n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-11-MultiHeadSelfAttention-Norm (None, 512, 768)           1536          Encoder-11-MultiHeadSelfAttention-Add[0]\n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-11-FeedForward (FeedForward)   (None, 512, 768)           4722432       Encoder-11-MultiHeadSelfAttention-Norm[0\n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-11-FeedForward-Dropout (Dropou (None, 512, 768)           0             Encoder-11-FeedForward[0][0]            \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-11-FeedForward-Add (Add)       (None, 512, 768)           0             Encoder-11-MultiHeadSelfAttention-Norm[0\n",
      "                                                                                Encoder-11-FeedForward-Dropout[0][0]    \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-11-FeedForward-Norm (LayerNorm (None, 512, 768)           1536          Encoder-11-FeedForward-Add[0][0]        \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-12-MultiHeadSelfAttention (Mul (None, None, 768)          2362368       Encoder-11-FeedForward-Norm[0][0]       \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-12-MultiHeadSelfAttention-Drop (None, None, 768)          0             Encoder-12-MultiHeadSelfAttention[0][0] \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-12-MultiHeadSelfAttention-Add  (None, 512, 768)           0             Encoder-11-FeedForward-Norm[0][0]       \n",
      "                                                                                Encoder-12-MultiHeadSelfAttention-Dropou\n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-12-MultiHeadSelfAttention-Norm (None, 512, 768)           1536          Encoder-12-MultiHeadSelfAttention-Add[0]\n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-12-FeedForward (FeedForward)   (None, 512, 768)           4722432       Encoder-12-MultiHeadSelfAttention-Norm[0\n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-12-FeedForward-Dropout (Dropou (None, 512, 768)           0             Encoder-12-FeedForward[0][0]            \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-12-FeedForward-Add (Add)       (None, 512, 768)           0             Encoder-12-MultiHeadSelfAttention-Norm[0\n",
      "                                                                                Encoder-12-FeedForward-Dropout[0][0]    \n",
      "________________________________________________________________________________________________________________________\n",
      "Encoder-12-FeedForward-Norm (LayerNorm (None, 512, 768)           1536          Encoder-12-FeedForward-Add[0][0]        \n",
      "________________________________________________________________________________________________________________________\n",
      "MLM-Dense (Dense)                      (None, 512, 768)           590592        Encoder-12-FeedForward-Norm[0][0]       \n",
      "________________________________________________________________________________________________________________________\n",
      "MLM-Norm (LayerNormalization)          (None, 512, 768)           1536          MLM-Dense[0][0]                         \n",
      "________________________________________________________________________________________________________________________\n",
      "Extract (Extract)                      (None, 768)                0             Encoder-12-FeedForward-Norm[0][0]       \n",
      "________________________________________________________________________________________________________________________\n",
      "MLM-Sim (EmbeddingSimilarity)          (None, 512, 21128)         21128         MLM-Norm[0][0]                          \n",
      "                                                                                Embedding-Token[0][1]                   \n",
      "________________________________________________________________________________________________________________________\n",
      "Input-Masked (InputLayer)              (None, 512)                0                                                     \n",
      "________________________________________________________________________________________________________________________\n",
      "NSP-Dense (Dense)                      (None, 768)                590592        Extract[0][0]                           \n",
      "________________________________________________________________________________________________________________________\n",
      "MLM (Masked)                           (None, 512, 21128)         0             MLM-Sim[0][0]                           \n",
      "                                                                                Input-Masked[0][0]                      \n",
      "________________________________________________________________________________________________________________________\n",
      "NSP (Dense)                            (None, 2)                  1538          NSP-Dense[0][0]                         \n",
      "========================================================================================================================\n",
      "Total params: 102,882,442\n",
      "Trainable params: 102,882,442\n",
      "Non-trainable params: 0\n",
      "________________________________________________________________________________________________________________________\n"
     ]
    }
   ],
   "source": [
    "from keras_bert import load_trained_model_from_checkpoint\n",
    "\n",
    "model = load_trained_model_from_checkpoint(config_path, checkpoint_path, training=True)\n",
    "model.summary(line_length=120)"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "FIR2QFEQoWlt",
    "colab_type": "text"
   },
   "source": [
    "## Predict Masked"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 8,
   "metadata": {
    "id": "x3z9_E5LoZFW",
    "colab_type": "code",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 54.0
    },
    "outputId": "21dcb2a9-c2e7-447e-9197-ed98d8892b29"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tokens: ['[CLS]', '[MASK]', '[MASK]', '是', '利', '用', '符', '号', '语', '言', '研', '究', '数', '量', '、', '结', '构', '、', '变', '化', '以', '及', '空', '间', '等', '概', '念', '的', '一', '门', '学', '科', '[SEP]']\n"
     ]
    }
   ],
   "source": [
    "from keras_bert import Tokenizer\n",
    "\n",
    "tokenizer = Tokenizer(token_dict)\n",
    "text = '数学是利用符号语言研究数量、结构、变化以及空间等概念的一门学科'\n",
    "tokens = tokenizer.tokenize(text)\n",
    "tokens[1] = tokens[2] = '[MASK]'\n",
    "print('Tokens:', tokens)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 9,
   "metadata": {
    "id": "PLsL7tsQofYo",
    "colab_type": "code",
    "colab": {}
   },
   "outputs": [],
   "source": [
    "import numpy as np\n",
    "\n",
    "indices = np.array([[token_dict[token] for token in tokens] + [0] * (512 - len(tokens))])\n",
    "segments = np.array([[0] * len(tokens) + [0] * (512 - len(tokens))])\n",
    "masks = np.array([[0, 1, 1] + [0] * (512 - 3)])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 10,
   "metadata": {
    "id": "Ei_RhvPiokxg",
    "colab_type": "code",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34.0
    },
    "outputId": "1f114d98-f276-4453-c525-7a363737b7d5"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Fill with:  ['数', '学']\n"
     ]
    }
   ],
   "source": [
    "predicts = model.predict([indices, segments, masks])[0].argmax(axis=-1).tolist()\n",
    "print('Fill with: ', list(map(lambda x: token_dict_inv[x], predicts[0][1:3])))"
   ]
  },
  {
   "cell_type": "markdown",
   "metadata": {
    "id": "Oi8_odqjqc9X",
    "colab_type": "text"
   },
   "source": [
    "## Predict Next Sentence"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 11,
   "metadata": {
    "id": "zDEzuid2oh2Z",
    "colab_type": "code",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 54.0
    },
    "outputId": "76ac1935-506f-48fb-ab55-7bce4aa68084"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tokens: ['[CLS]', '数', '学', '是', '利', '用', '符', '号', '语', '言', '研', '究', '數', '量', '、', '结', '构', '、', '变', '化', '以', '及', '空', '间', '等', '概', '念', '的', '一', '門', '学', '科', '。', '[SEP]', '从', '某', '种', '角', '度', '看', '屬', '於', '形', '式', '科', '學', '的', '一', '種', '。', '[SEP]']\n"
     ]
    }
   ],
   "source": [
    "sentence_1 = '数学是利用符号语言研究數量、结构、变化以及空间等概念的一門学科。'\n",
    "sentence_2 = '从某种角度看屬於形式科學的一種。'\n",
    "print('Tokens:', tokenizer.tokenize(first=sentence_1, second=sentence_2))\n",
    "indices, segments = tokenizer.encode(first=sentence_1, second=sentence_2, max_len=512)\n",
    "masks = np.array([[0] * 512])"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 12,
   "metadata": {
    "id": "g2wZ4q5Fqxpx",
    "colab_type": "code",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34.0
    },
    "outputId": "ba5d4adb-5c34-4687-d57b-cbc5fabe7102"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "从某种角度看屬於形式科學的一種。 is random next:  False\n"
     ]
    }
   ],
   "source": [
    "predicts = model.predict([np.array([indices]), np.array([segments]), masks])[1]\n",
    "print('%s is random next: ' % sentence_2, bool(np.argmax(predicts, axis=-1)[0]))"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 13,
   "metadata": {
    "id": "pFlpja57qzCN",
    "colab_type": "code",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 54.0
    },
    "outputId": "c33a036c-2bb4-4457-89ab-e3b9f2ac49f9"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "Tokens: ['[CLS]', '数', '学', '是', '利', '用', '符', '号', '语', '言', '研', '究', '數', '量', '、', '结', '构', '、', '变', '化', '以', '及', '空', '间', '等', '概', '念', '的', '一', '門', '学', '科', '。', '[SEP]', '任', '何', '一', '个', '希', '尔', '伯', '特', '空', '间', '都', '有', '一', '族', '标', '准', '正', '交', '基', '。', '[SEP]']\n"
     ]
    }
   ],
   "source": [
    "sentence_2 = '任何一个希尔伯特空间都有一族标准正交基。'\n",
    "print('Tokens:', tokenizer.tokenize(first=sentence_1, second=sentence_2))\n",
    "indices, segments = tokenizer.encode(first=sentence_1, second=sentence_2, max_len=512)"
   ]
  },
  {
   "cell_type": "code",
   "execution_count": 14,
   "metadata": {
    "id": "1XrXwEvgq1BU",
    "colab_type": "code",
    "colab": {
     "base_uri": "https://localhost:8080/",
     "height": 34.0
    },
    "outputId": "6857023f-fedc-4f94-b929-ee09c1650f2b"
   },
   "outputs": [
    {
     "name": "stdout",
     "output_type": "stream",
     "text": [
      "任何一个希尔伯特空间都有一族标准正交基。 is random next:  True\n"
     ]
    }
   ],
   "source": [
    "predicts = model.predict([np.array([indices]), np.array([segments]), masks])[1]\n",
    "print('%s is random next: ' % sentence_2, bool(np.argmax(predicts, axis=-1)[0]))"
   ]
  }
 ],
 "metadata": {
  "colab": {
   "name": "keras_bert_load_and_predict.ipynb",
   "version": "0.3.2",
   "provenance": [],
   "collapsed_sections": []
  },
  "kernelspec": {
   "name": "python3",
   "display_name": "Python 3"
  }
 },
 "nbformat": 4,
 "nbformat_minor": 0
}
